{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# 概述\n",
    "\n",
    "前几节我们尝试使用与房价预测相同的简单神经网络解决手写数字识别问题，但是效果并不理想。原因是手写数字识别的输入是28 × 28的像素值，输出是0-9的数字标签，而线性回归模型无法捕捉二维图像数据中蕴含的复杂信息，如 **图1** 所示。无论是牛顿第二定律任务，还是房价预测任务，输入特征和输出预测值之间的关系均可以使用“直线”刻画（使用线性方程来表达）。但手写数字识别任务的输入像素和输出数字标签之间的关系显然不是线性的，甚至这个关系复杂到我们靠人脑难以直观理解的程度。\n",
    "\n",
    "<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/1aef9cce4dca4a8f981b8666bec825ca372a3e7efbf64784be74881152a53014\" width=\"600\" hegiht=\"\" ></center>\n",
    "<center><br>图1：数字识别任务的输入和输入不是线性关系 </br></center>\n",
    "<br></br>\n",
    "\n",
    "因此，我们需要尝试使用其他更复杂、更强大的网络来构建手写数字识别任务，观察一下训练效果，即将“横纵式”教学法从横向展开，如 **图2** 所示。本节主要介绍两种常见的网络结构：经典的多层全连接神经网络和卷积神经网络。\n",
    "\n",
    "<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/421ca2aaf93646909c1e8fd4461e1a80c0305443ef514cb9ad55e564091cd797\" width=\"800\" hegiht=\"\" ></center>\n",
    "<center><br>图2：“横纵式”教学法 — 网络结构优化 </br></center>\n",
    "<br></br>\n",
    "\n",
    "### 数据处理\n",
    "\n",
    "在介绍网络结构前，需要先进行数据处理，代码与上一节保持一致。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-03-26 15:24:28,868-INFO: font search path ['/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/ttf', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/afm', '/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/mpl-data/fonts/pdfcorefonts']\n",
      "2020-03-26 15:24:29,250-INFO: generated new fontManager\n"
     ]
    }
   ],
   "source": [
    "#数据处理部分之前的代码，保持不变\n",
    "import os\n",
    "import random\n",
    "import paddle\n",
    "import paddle.fluid as fluid\n",
    "from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "\n",
    "import gzip\n",
    "import json\n",
    "\n",
    "# 定义数据集读取器\n",
    "def load_data(mode='train'):\n",
    "\n",
    "    # 数据文件\n",
    "    datafile = './work/mnist.json.gz'\n",
    "    print('loading mnist dataset from {} ......'.format(datafile))\n",
    "    data = json.load(gzip.open(datafile))\n",
    "    train_set, val_set, eval_set = data\n",
    "\n",
    "    # 数据集相关参数，图片高度IMG_ROWS, 图片宽度IMG_COLS\n",
    "    IMG_ROWS = 28\n",
    "    IMG_COLS = 28\n",
    "\n",
    "    if mode == 'train':\n",
    "        imgs = train_set[0]\n",
    "        labels = train_set[1]\n",
    "    elif mode == 'valid':\n",
    "        imgs = val_set[0]\n",
    "        labels = val_set[1]\n",
    "    elif mode == 'eval':\n",
    "        imgs = eval_set[0]\n",
    "        labels = eval_set[1]\n",
    "\n",
    "    imgs_length = len(imgs)\n",
    "\n",
    "    assert len(imgs) == len(labels), \\\n",
    "          \"length of train_imgs({}) should be the same as train_labels({})\".format(\n",
    "                  len(imgs), len(labels))\n",
    "\n",
    "    index_list = list(range(imgs_length))\n",
    "\n",
    "    # 读入数据时用到的batchsize\n",
    "    BATCHSIZE = 100\n",
    "\n",
    "    # 定义数据生成器\n",
    "    def data_generator():\n",
    "        if mode == 'train':\n",
    "            random.shuffle(index_list)\n",
    "        imgs_list = []\n",
    "        labels_list = []\n",
    "        for i in index_list:\n",
    "            img = np.reshape(imgs[i], [1, IMG_ROWS, IMG_COLS]).astype('float32')\n",
    "            label = np.reshape(labels[i], [1]).astype('float32')\n",
    "            imgs_list.append(img) \n",
    "            labels_list.append(label)\n",
    "            if len(imgs_list) == BATCHSIZE:\n",
    "                yield np.array(imgs_list), np.array(labels_list)\n",
    "                imgs_list = []\n",
    "                labels_list = []\n",
    "\n",
    "        # 如果剩余数据的数目小于BATCHSIZE，\n",
    "        # 则剩余数据一起构成一个大小为len(imgs_list)的mini-batch\n",
    "        if len(imgs_list) > 0:\n",
    "            yield np.array(imgs_list), np.array(labels_list)\n",
    "\n",
    "    return data_generator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# 经典的全连接神经网络\n",
    "\n",
    "\n",
    "经典的全连接神经网络来包含四层网络：两个隐含层，输入层和输出层，将手写数字识别任务通过全连接神经网络表示，如 **图3** 所示。\n",
    "\n",
    "<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/6a66c2c4674c4a19986df9e533da76dfefed8d6ccd854a8dae4cf668055e9cca\" width=\"500\" hegiht=\"\" ></center>\n",
    "<center><br>图3：手写数字识别任务的全连接神经网络结构</br></center>\n",
    "<br></br>\n",
    "\n",
    "* 输入层：将数据输入给神经网络。在该任务中，输入层的尺度为28×28的像素值。\n",
    "* 隐含层：增加网络深度和复杂度，隐含层的节点数是可以调整的，节点数越多，神经网络表示能力越强，参数量也会增加。在该任务中，中间的两个隐含层为10×10的结构，通常隐含层会比输入层的尺寸小，以便对关键信息做抽象，激活函数使用常见的sigmoid函数。\n",
    "* 输出层：输出网络计算结果，输出层的节点数是固定的。如果是回归问题，节点数量为需要回归的数字数量。如果是分类问题，则是分类标签的数量。在该任务中，模型的输出是回归一个数字，输出层的尺寸为1。\n",
    "\n",
    "------\n",
    "**说明：**\n",
    "\n",
    "隐含层引入非线性激活函数sigmoid是为了增加神经网络的非线性能力。\n",
    "\n",
    "举例来说，如果一个神经网络采用线性变换，有四个输入$x_1$~$x_4$，一个输出$y$。假设第一层的变换是$z_1=x_1-x_2$和$z_2=x_3+x_4$，第二层的变换是$y=z_1+z_2$，则将两层的变换展开后得到$y=x_1-x_2+x_3+x_4$。也就是说，无论中间累积了多少层线性变换，原始输入和最终输出之间依然是线性关系。\n",
    " \n",
    "------\n",
    "\n",
    "Sigmoid是早期神经网络模型中常见的非线性变换函数，通过如下代码，绘制出Sigmoid的函数曲线。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAH8lJREFUeJzt3Xl4XPV59vHvM6PNkmzZsrxgS7ZsY2xsMBhkY3DCEggYkpi0IWBICJAUSt6QkpbQF5qWpLR9m6VNmrZOAyUkJSyOIaS44BRwgJAQvIONV5A3LV4kb/KidWae948ZO0KxrcEe6cyM7s916Zoz5xzN3JjRraOz/I65OyIikl1CQQcQEZHUU7mLiGQhlbuISBZSuYuIZCGVu4hIFlK5i4hkIZW7iEgWUrmLiGQhlbuISBbKCeqNy8rKvLKyMqi3FxHJSCtWrNjt7kO6Wy+wcq+srGT58uVBvb2ISEYys23JrKfdMiIiWUjlLiKShVTuIiJZSOUuIpKFVO4iIlmo23I3s0fNrMHM1hxnuZnZv5pZtZmtNrPzUh9TREQ+iGS23H8CzDrB8quB8YmvO4D/OPVYIiJyKro9z93dXzezyhOsci3wmMfv17fYzAaa2WnuviNFGUVEPjB3py0Soz0ao60jRkf091/tEScSi9ERdaIxJxKNEYklpmNONPb75zF3ojGIHZl2J+a/fx7z+HvF3HEnvixx+9JYzHE4uswTua44cxjnVAzs0f/+VFzENBKo7fS8LjHvD8rdzO4gvnXPqFGjUvDWIpJN2iJR9jd3sK+5nX2HO2hq6eBAawcHWyMcaOngUFuE5vYIh9qiHE5Mt7RHaW6P0tIRpS0So7UjSltHvNTTkRkMLynIiHJPmrs/DDwMUFVVpTtzi/QRsZiz62Ardfta2NHUyo798ceGg600Hmxj96F2Gg+2cagtcsLXKcoLU5SfQ3F+DoX5YQpzcxhYmMeIgWEKcuNf+Tmho495OaGjj3nhELnhELk5IXJDRm44RE44/hgOGblhIxwKkRMyQmbkhOOP4ZARNiMUgnBimRmEE8ss8TxkRsjA+P1zMzA6TZv1zj84qSn3eqCi0/PyxDwR6WMOtUV4b9dB3ms4RHXDITY1HGLb3mZq9jbTHnn/lnT//ByGDMhnaP98Jo8YQFlxPmXFeQwszKO0KI+BhbmU9MtlQEEuA/rlUpyfQzjUe+WY6VJR7guAu8xsHnAB0KT97SLZr6U9ylu1+3inrol36ptYU9/E1j3NR5fn5YQYW1bEuCFFfGTiUEaVFlJRWsiIkgKGlxTQvyA3wPTZr9tyN7OngEuBMjOrA74O5AK4+w+BhcA1QDXQDNzWU2FFJDitHVEWb97D4s17WbplD6vrmojE4ntXRw7sx9kjS7ju/HImDB/A+KHFVJQWaks7QMmcLXNjN8sd+FLKEolI2tjR1MKi9Q28uqGB323aTWtHjNywMaV8ILdfPJbplaWcUzGQ0qK8oKNKF4EN+Ssi6Wnf4XYWrtnBc29vZ+mWvQCMKi1kzrRRXDZxKNMrS+mXFw44pXRH5S4iuDtvbt7DY7/bxq827KIj6owbUsQ9Hz2Dq88+jXFDinr1TA85dSp3kT6spT3Kz1fW8dibW3l31yEGFuZyy4WVfHLqSCaPGKBCz2Aqd5E+qLUjyuOLt/HDX29i96F2Jo8YwLevm8Lsc0ZQkKtdLtlA5S7Sh7RFosxbWsvcV6tpONjGReMGM/em8UwfU6qt9CyjchfpI3773m4eeG4Nm3cfZvqYUv71xqnMGDs46FjSQ1TuIlluZ1Mrf//COp5fvYPRgwv58a3TuHTCEG2pZzmVu0iWcneeXlHH3y5YS0fM+coV47nzknHap95HqNxFslBTSwdf+8U7PL96BzPGlvKtT01h9OCioGNJL1K5i2SZ5Vv3cve8t9l5oJV7r5rAnZeM0zAAfZDKXSSL/HTxNr6xYC0jB/bjmTsvZOqoQUFHkoCo3EWyQDTm/MML63n0jS18ZOJQvj/nXI262Mep3EUy3OG2CHfPe4tF6xu4bWYlf/2xSdoNIyp3kUy293A7n3t0Ceu2H+DBayfzuQsrg44kaULlLpKh9h5u5zOPLGFz4yF+dMs0Lps4NOhIkkZU7iIZqHOx/+fnqrj4jCFBR5I0Ewo6gIh8MCp2SYbKXSSDHG6LcPOPVOzSPe2WEckQkWiMLz/1Fht2HuQRFbt0Q1vuIhnA3Xnw+XW8sqGBb8yerIOn0i2Vu0gG+PEbW3nszW3c/uEx3DxjdNBxJAOo3EXS3MvrdvF3L6xj1uTh3H/1mUHHkQyhchdJYzV7mvmLn73N2SNL+N4N5xLSlaeSJJW7SJpqi0S566mVmMHcm86jX57GYZfk6WwZkTT1rV9uZHVdEw/dfD4VpYVBx5EMoy13kTT00tqdPPrGFm69qJKrJg8POo5kIJW7SJqp39/Cvc+s5uyRJdx/zcSg40iGUrmLpBF356vzVxGNOf9+01Tyc7SfXU6Oyl0kjTy1tJY3N+/hax87U/c8lVOichdJEzuaWvh/C9dz0bjBzJlWEXQcyXBJlbuZzTKzjWZWbWb3HWP5KDN71czeMrPVZnZN6qOKZC9352u/WEM05nzzj6dgpvPZ5dR0W+5mFgbmAlcDk4AbzWxSl9X+Gpjv7lOBOcAPUh1UJJs99/Z2XtnQwFevmsCowTrtUU5dMlvu04Fqd9/s7u3APODaLus4MCAxXQJsT11Ekey2+1Ab3/iftZw3aiC3XlQZdBzJEslcxDQSqO30vA64oMs63wBeMrMvA0XAFSlJJ9IHfOuXGzjcFuHb103Rja0lZVJ1QPVG4CfuXg5cA/zUzP7gtc3sDjNbbmbLGxsbU/TWIplrVe1+nl5Rx+c/NIbTh/YPOo5kkWTKvR7ofOi+PDGvsy8A8wHc/U2gACjr+kLu/rC7V7l71ZAhutGA9G3uzt/+z1rKivO567LTg44jWSaZcl8GjDezMWaWR/yA6YIu69QAlwOY2ZnEy12b5iInsGDVdlbW7OcvZ02gf0Fu0HEky3Rb7u4eAe4CXgTWEz8rZq2ZPWhmsxOr3QPcbmargKeAW93deyq0SKZrbo/wjws3cPbIEq47rzzoOJKFkhoV0t0XAgu7zHug0/Q6YGZqo4lkr/94bRM7D7Qy9zNTNUa79AhdoSrSy+r2NfPQ65u59twRnD+6NOg4kqVU7iK97PuL3gPg/87SiI/Sc1TuIr1oU+Mhfr6yjs9eMJoRA/sFHUeymMpdpBd97+V3KcgN838uGxd0FMlyKneRXrJu+wGeX72D22ZWUlacH3QcyXIqd5Fe8t2XN9K/IIc7Pqytdul5KneRXrCyZh+L1jfwpxePpaRQFyxJz1O5i/SCf35pI4OL8rht5pigo0gfoXIX6WHLtu7ljeo9fPHScRTlJ3XdoMgpU7mL9LAfvFpNaVEen7lgdNBRpA9RuYv0oLXbm3h1YyOfn1lJv7xw0HGkD1G5i/Sg/3htE8X5Odx8YWXQUaSPUbmL9JAtuw+z8J0dfGbGKEr66QwZ6V0qd5Ee8tCvN5ETDvGFD+kMGel9KneRHrCzqZWfr6zj+qpyhvYvCDqO9EEqd5Ee8MhvNhNz+NOLdTWqBEPlLpJiTS0dPLm0hk9MOY2K0sKg40gfpXIXSbGfLauhuT3Kn3x4bNBRpA9TuYukUCQa479+t40LxpRy1siSoONIH6ZyF0mhF9fuon5/i86QkcCp3EVS6Ee/3cyo0kIuP3NY0FGkj1O5i6TIWzX7WFmzn9tmVhIOWdBxpI9TuYukyKNvbKV/fg6frqoIOoqIyl0kFXY0tbDwnR3cMK2CYg3rK2lA5S6SAo+9uQ1355aLKoOOIgKo3EVOWWtHlKeW1nDlpOG6aEnShspd5BQ9v3oH+5s7+NxFuhmHpA+Vu8gp+unibYwbUsSFYwcHHUXkKJW7yClYXbefVbX7uXnGaMx0+qOkD5W7yCl4fPE2+uWG+ePzy4OOIvI+KneRk9TU3MFzb2/nk1NHMqBAd1qS9JJUuZvZLDPbaGbVZnbfcda53szWmdlaM3sytTFF0s/TK2ppi8S4eYYOpEr66fZqCzMLA3OBjwJ1wDIzW+Du6zqtMx64H5jp7vvMbGhPBRZJB7GY8/jibVSNHsSkEQOCjiPyB5LZcp8OVLv7ZndvB+YB13ZZ53ZgrrvvA3D3htTGFEkvv63ezdY9zdx8obbaJT0lU+4jgdpOz+sS8zo7AzjDzN4ws8VmNutYL2Rmd5jZcjNb3tjYeHKJRdLA44u3Mbgoj1lnDQ86isgxpeqAag4wHrgUuBH4TzMb2HUld3/Y3avcvWrIkCEpemuR3rXrQCu/2tDAp6sqyM8JBx1H5JiSKfd6oPMwd+WJeZ3VAQvcvcPdtwDvEi97kazz9PJaojFnzjSN/ijpK5lyXwaMN7MxZpYHzAEWdFnnv4lvtWNmZcR302xOYU6RtBCLOU8trWXm6YOpLCsKOo7IcXVb7u4eAe4CXgTWA/Pdfa2ZPWhmsxOrvQjsMbN1wKvAve6+p6dCiwTlN9W7qd/fwo3TRwUdReSEkhp42t0XAgu7zHug07QDf5H4EslaTy6JH0i9cpIOpEp60xWqIklqONDKovUNXHd+OXk5+tGR9KZPqEiSnl5RRzTm3KADqZIBVO4iSYgfSK3hwrGDGTukOOg4It1SuYsk4bfVu6nb18KNF+hAqmQGlbtIEuYtq2FQYS5XTR4WdBSRpKjcRbqx+1AbL6/bxR+fV64rUiVjqNxFuvGLlfV0RHUgVTKLyl3kBNydectqOG/UQM4Y1j/oOCJJU7mLnMCKbfvY1HiYOdN0IFUyi8pd5ATmLaulKC/Mx6acFnQUkQ9E5S5yHAdaO3hh9Q5mnzuCovykRuoQSRsqd5Hj+J9V22npiHKDdslIBlK5ixzH/GW1TBzen3PKS4KOIvKBqdxFjmHd9gOsqmvihmkVmFnQcUQ+MJW7yDHMX15LXjjEJ8/tertgkcygchfporUjyrMr65h11nAGFeUFHUfkpKjcRbr43zU7OdAa0T1SJaOp3EW6mLeshlGlhcwYOzjoKCInTeUu0smW3YdZvHkvN0yrIBTSgVTJXCp3kU7mL68lHDKuO7886Cgip0TlLpLQEY3xzIo6LpswlGEDCoKOI3JKVO4iCa9uaKDxYJsOpEpWULmLJPxsWS1D++dz6YQhQUcROWUqdxFgZ1Mrr25s4NNV5eSE9WMhmU+fYhHiB1JjDtdXaZeMZAeVu/R50Zjzs2W1fOj0MkYPLgo6jkhKqNylz3v9vUbq97dw43QN7SvZQ+Uufd5TS2ooK87jo5OGBR1FJGVU7tKn7TrQyq82NHDd+RXk5ejHQbKHPs3Spz29vJZozHVuu2SdpMrdzGaZ2UYzqzaz+06w3qfMzM2sKnURRXpGLOY8tbSWmacPprJMB1Ilu3Rb7mYWBuYCVwOTgBvNbNIx1usP3A0sSXVIkZ7wm+rdOpAqWSuZLffpQLW7b3b3dmAecO0x1vs74FtAawrzifSYJ5dsY3BRHldOGh50FJGUS6bcRwK1nZ7XJeYdZWbnARXu/kIKs4n0mF0HWlm0voFPnV+uA6mSlU75U21mIeC7wD1JrHuHmS03s+WNjY2n+tYiJ+2ppTVEY85N2iUjWSqZcq8HOp9KUJ6Yd0R/4CzgNTPbCswAFhzroKq7P+zuVe5eNWSIBmeSYHREYzy1tIZLzhiiA6mStZIp92XAeDMbY2Z5wBxgwZGF7t7k7mXuXunulcBiYLa7L++RxCKnaNG6Xew60MbNM0YHHUWkx3Rb7u4eAe4CXgTWA/Pdfa2ZPWhms3s6oEiq/XTxNkYO7MdlE4cGHUWkx+Qks5K7LwQWdpn3wHHWvfTUY4n0jOqGg/xu0x7uvWoCYd0jVbKYThOQPuXxxTXkhUPcoCtSJcup3KXPaG6P8PMVdVxz9nDKivODjiPSo1Tu0mc89/Z2DrZFuPlCHUiV7Kdylz7B3XnszW2cedoAzhs1KOg4Ij1O5S59wuLNe1m/4wC3XDgaMx1Ileyncpc+4dE3tlBalMcnp47sfmWRLKByl6y3dfdhFq3fxWcuGEVBbjjoOCK9QuUuWe8nv9tKTsh0Rar0KSp3yWoHWjt4enktn5gygqEDCoKOI9JrVO6S1eYvq+Vwe5TPf2hM0FFEepXKXbJWJBrjx29sZfqYUs4aWRJ0HJFepXKXrPXyul3U72/hC9pqlz5I5S5Zyd155LdbqCjtxxVnDgs6jkivU7lLVlq6ZS8rtu3j9g+P1eiP0iep3CUrzX1tE2XFeVxfpdEfpW9SuUvWWVPfxOvvNvL5D43RRUvSZ6ncJev84LVq+ufn8FldtCR9mMpdssqmxkP8cs1OPnfRaAYU5AYdRyQwKnfJKg/9ehN54RC3zdTpj9K3qdwla2zf38KzK+uZM61Cd1qSPk/lLlnj4dc3A3D7xWMDTiISPJW7ZIX6/S08uaSGT51XTvmgwqDjiARO5S5Z4d9+9R4Af3bF+ICTiKQHlbtkvC27D/P0ijpuumAUIwf2CzqOSFpQuUvG+5dF75IXDvGly04POopI2lC5S0bbuPMgC1Zt59aZlQzprzNkRI5QuUtG++eXNlKcl8Of6gwZkfdRuUvGWlW7n5fW7eL2i8cysDAv6DgiaUXlLhnJ3fn7F9YxuChPt9ATOQaVu2Sk51fvYNnWfXz1qgkU5+cEHUck7ajcJeO0tEf5x4XrmTxigMZrFzmOpMrdzGaZ2UYzqzaz+46x/C/MbJ2ZrTazX5mZxlqVHvPQ65vY3tTK1z8xWXdZEjmObsvdzMLAXOBqYBJwo5lN6rLaW0CVu08BngG+neqgIhAfZuCHv97Ex6ecxvQxpUHHEUlbyWy5Tweq3X2zu7cD84BrO6/g7q+6e3Pi6WKgPLUxReL+ceF63OH+a84MOopIWkum3EcCtZ2e1yXmHc8XgF8ea4GZ3WFmy81seWNjY/IpRYA3N+3h+dU7uPOScRpmQKQbKT2gamafBaqA7xxrubs/7O5V7l41ZMiQVL61ZLmW9ij3P7uaUaWF3HnJuKDjiKS9ZM4hqwc6n5JQnpj3PmZ2BfA14BJ3b0tNPJG47y16l617mnny9gvol6ebXot0J5kt92XAeDMbY2Z5wBxgQecVzGwq8BAw290bUh9T+rJVtft55DebuXH6KC4aVxZ0HJGM0G25u3sEuAt4EVgPzHf3tWb2oJnNTqz2HaAYeNrM3jazBcd5OZEPpD0S4y+fWc3Q/gXcf83EoOOIZIykLu1z94XAwi7zHug0fUWKc4kAMPfVajbuOsijt1YxoCA36DgiGUNXqEraWlPfxA9eq+aT547gIxOHBR1HJKOo3CUtHWzt4EtPrqSsOJ+vf2Jy0HFEMo5GXJK04+7c9+w71O1r4Wd3zGBQkYbzFfmgtOUuaeeJJTW8sHoH91x5BlWVGmJA5GSo3CWtrN3exIPPr+OSM4Zw58W6WEnkZKncJW00NXdw15NvMagwl+9efw4hjfgoctK0z13SQnskxp2Pr6BuXzNP/MkMBhfrZtcip0LlLoFzd/7qF+/w5uY9fO+GczSUr0gKaLeMBG7uq9U8s6KOuy8fzx9N1WjRIqmgcpdAPfd2Pf/00rv80dSRfOWK8UHHEckaKncJzItrd3LP/FVMH1PKNz91NmY6gCqSKip3CcSLa3fypSdWctbIEh65pYr8HA3jK5JKKnfpdUeK/ezyEh77wnQNCCbSA3S2jPSq/12zg7uefIuzy0v4r8+r2EV6ispdeoW786PfbuEfFq5nasVAfqJiF+lRKnfpcR3RGF9fsJYnl9Rw9VnD+e715+pWeSI9TOUuPepAawdfemIlv3lvN1+8dBz3XjlBwwqI9AKVu/SYlTX7uHveW+zY38q3PzWF66dVdP9NIpISKndJuWjM+eGvN/Hdl99l+IAC5t0xQ0P3ivQylbukVN2+Zr769CoWb97Lx6ecxj/80dmU9NOBU5HepnKXlGiLRHnkN1v4t1feI2TGd66bwnXnl+uqU5GAqNzllL1RvZu/eW4NmxsPM2vycP7mE5MYObBf0LFE+jSVu5y0VbX7+d6id3ltYyOjBxfyk9umcemEoUHHEhFU7nIS1tQ38S+L3mXR+gYGFeZy39UTufWiSgpyde66SLpQuUtSItEYi9bv4ie/28rizXsp6ZfLvVdN4JaLKinO18dIJN3op1JOqHZvM8+9Xc+TS2rY3tTKyIH9uO/qidx0wSgNHyCSxlTu8gd2H2rjl+/s4Lm3t7N82z4AZp4+mG/MnszlZw4jrCtMRdKeyl1wd9ZuP8ArGxp4ZUMDq+r24w5nDCvm3qsmMPucEVSUFgYdU0Q+AJV7HxSJxtiw8yBLt+yNf23dy97D7ZjBlPKBfOXyM7hy8jDOPG1A0FFF5CSp3LNcS3uU6oZDbNh5gDX1TbxT38S6HQdo7YgBUFHaj8smDOXCcYO5dMIQyorzA04sIqmQVLmb2Szg+0AYeMTdv9lleT7wGHA+sAe4wd23pjaqHM/htgg7mlqo2dvMtj3NRx/fazhI3b4W3OPrFeaFOWtECTdNH805FSVMqyxlhC42EslK3Za7mYWBucBHgTpgmZktcPd1nVb7ArDP3U83sznAt4AbeiJwXxCLOYfaIzQ1d7CvuZ19zR3sO9zOnsPtNB5sY/ehNhoPtrHrQCvb97dwoDXyvu8vzAszqrSQcysG8enzKxg/tJjxw/ozpqxIB0NF+ohkttynA9XuvhnAzOYB1wKdy/1a4BuJ6WeAfzczcz+yzZhZ3J1ozInEnFhi+sjzSNTpiMYS0zE6Es87ojHaozHaI4mvxHRbJEZrR5TWjvhjS0eUlvYoze1RWjoiHGqLcrgtwuG2CAdbIxxs7eBgW4Tj/cvlho2y4nzKivMpH1TI9DGlnFbSj9NKCqgoLWT04EIGF+VpTBeRPi6Zch8J1HZ6XgdccLx13D1iZk3AYGB3KkJ2Nn9ZLQ+9vgkHSBSgAzF33MFxYrHEfPejy2JOfLk7UXdisfj6UY8XeCwWXy+aeJ2ekp8TojAvTL/cMP3ywhTn51BckMPgokKK83MY0C83/lUQnx5UmMegwlwGFuZRVpxHSb9cFbeIdKtXD6ia2R3AHQCjRo06qdcYVJTHxOEDwMDirwlAqNPzo4+JeeFQYjqxLBwyQnbk68hyIxzi6PyckBEKGeFQfPr3jyFywkZu2MgJhcgNh8jL6TwdIj8n/pgXDpGfG6IgJ0xBbpj8nJDuQiQivSKZcq8HOt9Cpzwx71jr1JlZDlBC/MDq+7j7w8DDAFVVVSe1ffzRScP46KRhJ/OtIiJ9RiiJdZYB481sjJnlAXOABV3WWQDckpi+DnglU/e3i4hkg2633BP70O8CXiR+KuSj7r7WzB4Elrv7AuBHwE/NrBrYS/wXgIiIBCSpfe7uvhBY2GXeA52mW4FPpzaaiIicrGR2y4iISIZRuYuIZCGVu4hIFlK5i4hkIZW7iEgWsqBORzezRmDbSX57GT0wtEGKpGu2dM0F6ZstXXNB+mZL11yQPdlGu/uQ7lYKrNxPhZktd/eqoHMcS7pmS9dckL7Z0jUXpG+2dM0FfS+bdsuIiGQhlbuISBbK1HJ/OOgAJ5Cu2dI1F6RvtnTNBembLV1zQR/LlpH73EVE5MQydctdREROIGPL3czONbPFZva2mS03s+lBZ+rMzL5sZhvMbK2ZfTvoPJ2Z2T1m5mZWFnSWI8zsO4l/r9Vm9gszGxhwnllmttHMqs3sviCzHGFmFWb2qpmtS3yu7g46U1dmFjazt8zs+aCzdGZmA83smcRnbL2ZXRh0JgAz+/PE/8s1ZvaUmRWk6rUzttyBbwN/6+7nAg8knqcFM7uM+H1lz3H3ycA/BRzpKDOrAK4EaoLO0sXLwFnuPgV4F7g/qCCdbgp/NTAJuNHMJgWVp5MIcI+7TwJmAF9Kk1yd3Q2sDzrEMXwf+F93nwicQxpkNLORwJ8BVe5+FvEh1VM2XHoml7sDAxLTJcD2ALN09UXgm+7eBuDuDQHn6ex7wF9y9A606cHdX3L3SOLpYuJ3/ArK0ZvCu3s7cOSm8IFy9x3uvjIxfZB4QY0MNtXvmVk58DHgkaCzdGZmJcDFxO87gbu3u/v+YFMdlQP0S9zBrpAU9lgml/tXgO+YWS3xLePAtvSO4Qzgw2a2xMx+bWbTgg4EYGbXAvXuviroLN34PPDLAN//WDeFT5sSBTCzSmAqsCTYJO/zL8Q3HGJBB+liDNAI/Dixy+gRMysKOpS71xPvrhpgB9Dk7i+l6vV79QbZH5SZLQKGH2PR14DLgT9395+b2fXEfytfkSbZcoBS4n86TwPmm9nY3rj1YDe5/or4LplAnCibuz+XWOdrxHc/PNGb2TKJmRUDPwe+4u4Hgs4DYGYfBxrcfYWZXRp0ni5ygPOAL7v7EjP7PnAf8DdBhjKzQcT/IhwD7AeeNrPPuvvjqXj9tC53dz9uWZvZY8T37wE8TS//KdhNti8CzybKfKmZxYiPHdEYVC4zO5v4h2iVmUF8t8dKM5vu7jt7OteJsh1hZrcCHwcuD/gevMncFD4QZpZLvNifcPdng87TyUxgtpldAxQAA8zscXf/bMC5IP6XV527H/kr5xni5R60K4At7t4IYGbPAhcBKSn3TN4tsx24JDH9EeC9ALN09d/AZQBmdgaQR8ADFrn7O+4+1N0r3b2S+Af+vN4q9u6Y2Szif9LPdvfmgOMkc1P4Xmfx38o/Ata7+3eDztOZu9/v7uWJz9Yc4JU0KXYSn/FaM5uQmHU5sC7ASEfUADPMrDDx//ZyUnigN6233LtxO/D9xIGIVuCOgPN09ijwqJmtAdqBWwLeEs0E/w7kAy8n/rJY7O53BhHkeDeFDyJLFzOBm4F3zOztxLy/StzjWE7sy8ATiV/Wm4HbAs5DYhfRM8BK4rsi3yKFV6rqClURkSyUybtlRETkOFTuIiJZSOUuIpKFVO4iIllI5S4ikoVU7iIiWUjlLiKShVTuIiJZ6P8DtkRmteUBWjsAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def sigmoid(x):\n",
    "    # 直接返回sigmoid函数\n",
    "    return 1. / (1. + np.exp(-x))\n",
    " \n",
    "# param:起点，终点，间距\n",
    "x = np.arange(-8, 8, 0.2)\n",
    "y = sigmoid(x)\n",
    "plt.plot(x, y)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "针对手写数字识别的任务，网络层的设计如下：\n",
    "\n",
    "* 输入层的尺度为28×28，但批次计算的时候会统一加1个维度（大小为bitchsize）。\n",
    "* 中间的两个隐含层为10×10的结构，激活函数使用常见的sigmoid函数。\n",
    "* 与房价预测模型一样，模型的输出是回归一个数字，输出层的尺寸设置成1。\n",
    "\n",
    "下述代码为经典全连接神经网络的实现。完成网络结构定义后，即可训练神经网络。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# 多层全连接神经网络实现\n",
    "class MNIST(fluid.dygraph.Layer):\n",
    "    def __init__(self, name_scope):\n",
    "        super(MNIST, self).__init__(name_scope)\n",
    "        # 定义两层全连接隐含层，输出维度是10，激活函数为sigmoid\n",
    "        self.fc1 = Linear(input_dim=784, output_dim=10, act='sigmoid') # 隐含层节点为10，可根据任务调整\n",
    "        self.fc2 = Linear(input_dim=10, output_dim=10, act='sigmoid')\n",
    "        # 定义一层全连接输出层，输出维度是1，不使用激活函数\n",
    "        self.fc3 = Linear(input_dim=10, output_dim=1, act=None)\n",
    "    \n",
    "    # 定义网络的前向计算\n",
    "    def forward(self, inputs, label=None):\n",
    "        inputs = fluid.layers.reshape(inputs, [inputs.shape[0], 784])\n",
    "        outputs1 = self.fc1(inputs)\n",
    "        outputs2 = self.fc2(outputs1)\n",
    "        outputs_final = self.fc3(outputs2)\n",
    "        return outputs_final"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading mnist dataset from ./work/mnist.json.gz ......\n",
      "epoch: 0, batch: 0, loss is: [27.740425]\n",
      "epoch: 0, batch: 200, loss is: [5.4588423]\n",
      "epoch: 0, batch: 400, loss is: [3.9063952]\n",
      "epoch: 1, batch: 0, loss is: [3.8620145]\n",
      "epoch: 1, batch: 200, loss is: [4.6423216]\n",
      "epoch: 1, batch: 400, loss is: [3.9099925]\n",
      "epoch: 2, batch: 0, loss is: [3.3493927]\n",
      "epoch: 2, batch: 200, loss is: [2.8054562]\n",
      "epoch: 2, batch: 400, loss is: [2.8475616]\n",
      "epoch: 3, batch: 0, loss is: [3.1059093]\n",
      "epoch: 3, batch: 200, loss is: [2.8764062]\n",
      "epoch: 3, batch: 400, loss is: [2.248354]\n",
      "epoch: 4, batch: 0, loss is: [2.3325133]\n",
      "epoch: 4, batch: 200, loss is: [2.9140906]\n",
      "epoch: 4, batch: 400, loss is: [1.6771106]\n"
     ]
    }
   ],
   "source": [
    "#网络结构部分之后的代码，保持不变\n",
    "with fluid.dygraph.guard():\n",
    "    model = MNIST(\"mnist\")\n",
    "    model.train()\n",
    "    #调用加载数据的函数，获得MNIST训练数据集\n",
    "    train_loader = load_data('train')\n",
    "    # 使用SGD优化器，learning_rate设置为0.01\n",
    "    optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.01, parameter_list=model.parameters())\n",
    "    # 训练5轮\n",
    "    EPOCH_NUM = 5\n",
    "    for epoch_id in range(EPOCH_NUM):\n",
    "        for batch_id, data in enumerate(train_loader()):\n",
    "            #准备数据\n",
    "            image_data, label_data = data\n",
    "            image = fluid.dygraph.to_variable(image_data)\n",
    "            label = fluid.dygraph.to_variable(label_data)\n",
    "            \n",
    "            #前向计算的过程\n",
    "            predict = model(image)\n",
    "            \n",
    "            #计算损失，取一个批次样本损失的平均值\n",
    "            loss = fluid.layers.square_error_cost(predict, label)\n",
    "            avg_loss = fluid.layers.mean(loss)\n",
    "            \n",
    "            #每训练了200批次的数据，打印下当前Loss的情况\n",
    "            if batch_id % 200 == 0:\n",
    "                print(\"epoch: {}, batch: {}, loss is: {}\".format(epoch_id, batch_id, avg_loss.numpy()))\n",
    "            \n",
    "            #后向传播，更新参数的过程\n",
    "            avg_loss.backward()\n",
    "            optimizer.minimize(avg_loss)\n",
    "            model.clear_gradients()\n",
    "\n",
    "    #保存模型参数\n",
    "    fluid.save_dygraph(model.state_dict(), 'mnist')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# 卷积神经网络\n",
    "\n",
    "虽然使用经典的神经网络可以提升一定的准确率，但对于计算机视觉问题，效果最好的模型仍然是卷积神经网络。卷积神经网络针对视觉问题的特点进行了网络结构优化，更适合处理视觉问题。\n",
    "\n",
    "卷积神经网络由多个卷积层和池化层组成，如 **图4** 所示。卷积层负责对输入进行扫描以生成更抽象的特征表示，池化层对这些特征表示进行过滤，保留最关键的特征信息。\n",
    "<center><img src=\"https://ai-studio-static-online.cdn.bcebos.com/ae78212c521e4814a07714a613ce1fe90832de104d7c46ed93fedb7bce18f6a0\" width=\"700\" hegiht=\"\" ></center>\n",
    "<center><br>图4：在处理计算机视觉任务中大放异彩的卷积神经网络</br></center>\n",
    "<br></br>\n",
    "\n",
    "-------\n",
    "**说明：**\n",
    "\n",
    "本节只介绍手写数字识别在卷积神经网络的实现以及它带来的效果提升。读者可以将卷积神经网络先简单的理解成是一种比经典的全连接神经网络更强大的模型即可，更详细的原理和实现在接下来的[第四章-计算机视觉-卷积神经网络基础](https://aistudio.baidu.com/aistudio/projectdetail/236887)中讲述。\n",
    "\n",
    "------\n",
    "\n",
    "两层卷积和池化的神经网络实现如下代码所示。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# 多层卷积神经网络实现\n",
    "class MNIST(fluid.dygraph.Layer):\n",
    "     def __init__(self, name_scope):\n",
    "         super(MNIST, self).__init__(name_scope)\n",
    "         \n",
    "         # 定义卷积层，输出特征通道num_filters设置为20，卷积核的大小filter_size为5，卷积步长stride=1，padding=2\n",
    "         # 激活函数使用relu\n",
    "         self.conv1 = Conv2D(num_channels=1, num_filters=20, filter_size=5, stride=1, padding=2, act='relu')\n",
    "         # 定义池化层，池化核pool_size=2，池化步长为2，选择最大池化方式\n",
    "         self.pool1 = Pool2D(pool_size=2, pool_stride=2, pool_type='max')\n",
    "         # 定义卷积层，输出特征通道num_filters设置为20，卷积核的大小filter_size为5，卷积步长stride=1，padding=2\n",
    "         self.conv2 = Conv2D(num_channels=20, num_filters=20, filter_size=5, stride=1, padding=2, act='relu')\n",
    "         # 定义池化层，池化核pool_size=2，池化步长为2，选择最大池化方式\n",
    "         self.pool2 = Pool2D(pool_size=2, pool_stride=2, pool_type='max')\n",
    "         # 定义一层全连接层，输出维度是1，不使用激活函数\n",
    "         self.fc = Linear(input_dim=980, output_dim=1, act=None)\n",
    "         \n",
    "    # 定义网络前向计算过程，卷积后紧接着使用池化层，最后使用全连接层计算最终输出\n",
    "     def forward(self, inputs):\n",
    "         x = self.conv1(inputs)\n",
    "         x = self.pool1(x)\n",
    "         x = self.conv2(x)\n",
    "         x = self.pool2(x)\n",
    "         x = fluid.layers.reshape(x, [x.shape[0], -1])\n",
    "         x = self.fc(x)\n",
    "         return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "训练定义好的卷积神经网络，代码如下所示。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading mnist dataset from ./work/mnist.json.gz ......\n",
      "epoch: 0, batch: 0, loss is: [31.675833]\n",
      "epoch: 0, batch: 200, loss is: [9.248349]\n",
      "epoch: 0, batch: 400, loss is: [3.2532346]\n",
      "epoch: 1, batch: 0, loss is: [2.5735705]\n",
      "epoch: 1, batch: 200, loss is: [2.7086043]\n",
      "epoch: 1, batch: 400, loss is: [2.351327]\n",
      "epoch: 2, batch: 0, loss is: [2.2003784]\n",
      "epoch: 2, batch: 200, loss is: [2.53069]\n",
      "epoch: 2, batch: 400, loss is: [2.154322]\n",
      "epoch: 3, batch: 0, loss is: [1.8227897]\n",
      "epoch: 3, batch: 200, loss is: [1.8546791]\n",
      "epoch: 3, batch: 400, loss is: [2.3879793]\n",
      "epoch: 4, batch: 0, loss is: [2.6370738]\n",
      "epoch: 4, batch: 200, loss is: [1.6437341]\n",
      "epoch: 4, batch: 400, loss is: [1.6468849]\n"
     ]
    }
   ],
   "source": [
    "#网络结构部分之后的代码，保持不变\n",
    "with fluid.dygraph.guard():\n",
    "    model = MNIST(\"mnist\")\n",
    "    model.train()\n",
    "    #调用加载数据的函数\n",
    "    train_loader = load_data('train')\n",
    "    optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.01, parameter_list=model.parameters())\n",
    "    EPOCH_NUM = 5\n",
    "    for epoch_id in range(EPOCH_NUM):\n",
    "        for batch_id, data in enumerate(train_loader()):\n",
    "            #准备数据\n",
    "            image_data, label_data = data\n",
    "            image = fluid.dygraph.to_variable(image_data)\n",
    "            label = fluid.dygraph.to_variable(label_data)\n",
    "             \n",
    "            #前向计算的过程\n",
    "            predict = model(image)\n",
    "            \n",
    "            #计算损失，取一个批次样本损失的平均值\n",
    "            loss = fluid.layers.square_error_cost(predict, label)\n",
    "            avg_loss = fluid.layers.mean(loss)\n",
    "            \n",
    "            #每训练了100批次的数据，打印下当前Loss的情况\n",
    "            if batch_id % 200 == 0:\n",
    "                print(\"epoch: {}, batch: {}, loss is: {}\".format(epoch_id, batch_id, avg_loss.numpy()))\n",
    "            \n",
    "            #后向传播，更新参数的过程\n",
    "            avg_loss.backward()\n",
    "            optimizer.minimize(avg_loss)\n",
    "            model.clear_gradients()\n",
    "\n",
    "    #保存模型参数\n",
    "    fluid.save_dygraph(model.state_dict(), 'mnist')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "比较经典全连接神经网络和卷积神经网络的损失变化，可以发现卷积神经网络的损失值下降更快，且最终的损失值更小。\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PaddlePaddle 1.7.0 (Python 3.5)",
   "language": "python",
   "name": "py35-paddle1.2.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
